{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "7629f8f1",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-06-08T14:17:31.975342Z",
     "start_time": "2023-06-08T14:17:31.968486Z"
    }
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "from transformers import pipeline, set_seed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "69bacc12",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-06-08T14:11:33.878433Z",
     "start_time": "2023-06-08T14:11:33.873279Z"
    }
   },
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ['HTTP_PROXY'] = 'http://127.0.0.1:7890'\n",
    "os.environ['HTTPS_PROXY'] = 'http://127.0.0.1:7890'"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "247ceaee",
   "metadata": {},
   "source": [
    "## tasks"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1a5a4758",
   "metadata": {},
   "source": [
    "- NLP\n",
    "    - `fill-mask`\n",
    "        - `pipe = pipeline(\"fill-mask\", model=\"bert-base-uncased\")`\n",
    "    - 文本分类：`pipeline(\"text-classification\")`\n",
    "        - sentiment analysis\n",
    "    - ner：`pipeline(\"ner\", aggregation_strategy=\"simple\")`\n",
    "        - tagger\n",
    "    - qa: `pipeline(\"question-answering\")`\n",
    "        - reader\n",
    "    - Summarization: `pipeline(\"summarization\")`\n",
    "    - Translation: \n",
    "        ```\n",
    "        translator = pipeline(\"translation_en_to_de\", model=\"Helsinki-NLP/opus-mt-en-de\")\n",
    "        ```\n",
    "    - Text Generation：`generator = pipeline(\"text-generation\")`\n",
    "    \n",
    "- CV\n",
    "    - `object-detection`\n",
    "\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "00ff537c",
   "metadata": {},
   "source": [
    "## 参数"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c7a1578d",
   "metadata": {},
   "source": [
    "- `device`\n",
    "    - `(int, optional, defaults to -1)`\n",
    "    - 前向过程（推理）时，不需要显式地将 query/input 转到对应的设备上，pipeline 会自动处理。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "10faaab3",
   "metadata": {},
   "source": [
    "## examples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "f61376de",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-06-08T14:13:51.401746Z",
     "start_time": "2023-06-08T14:13:47.853235Z"
    },
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of OpenAIGPTLMHeadModel were not initialized from the model checkpoint at openai-gpt and are newly initialized: ['position_ids']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    }
   ],
   "source": [
    "# Build one text-generation pipeline per checkpoint so the two models can be compared below.\n",
    "gpt_gen = pipeline(task='text-generation', model='openai-gpt')\n",
    "gpt2_gen = pipeline(task='text-generation', model='gpt2')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "1378bd53",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-06-08T14:15:39.475139Z",
     "start_time": "2023-06-08T14:15:39.466275Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(transformers.models.openai.modeling_openai.OpenAIGPTPreTrainedModel,)"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Show the direct base classes of the model object held by the openai-gpt pipeline.\n",
    "type(gpt_gen.model).__bases__"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "619dc63e",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-06-08T14:18:01.589336Z",
     "start_time": "2023-06-08T14:18:01.583780Z"
    }
   },
   "outputs": [],
   "source": [
    "def model_size(model: nn.Module):\n",
    "    return sum(para.numel() for para in model.parameters())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "063166c8",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-06-08T14:20:40.213337Z",
     "start_time": "2023-06-08T14:20:40.202333Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "116.53M\n",
      "124.44M\n"
     ]
    }
   ],
   "source": [
    "# Report each model's parameter count in millions (two decimals), one line per model.\n",
    "print(\n",
    "    *(f'{model_size(gen.model)/1000**2:.2f}M' for gen in (gpt_gen, gpt2_gen)),\n",
    "    sep='\\n',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "903dfaf0",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
